
[Flax] Add remat (gradient checkpointing) #17843

Merged: 11 commits into huggingface:main on Jul 1, 2022

Conversation

@sanchit-gandhi (Contributor) commented on Jun 23, 2022

What does this PR do?

Adds gradient checkpointing in Flax (cf. #17399). The API currently takes the form of a method:

from transformers import BertConfig, FlaxBertModel

model = FlaxBertModel(BertConfig())
model.enable_gradient_checkpointing()

Note: checkpointing has currently only been implemented for FlaxBert. Implementing for all Flax models is a TODO.
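
For illustration, here is a minimal end-to-end usage sketch (the dummy inputs and loss below are hypothetical, not part of the PR); the memory savings from recomputing activations show up when taking gradients:

import jax
import jax.numpy as jnp
from transformers import BertConfig, FlaxBertModel

model = FlaxBertModel(BertConfig())
model.enable_gradient_checkpointing()

input_ids = jnp.ones((2, 16), dtype="i4")  # dummy batch

def loss_fn(params):
    hidden_states = model(input_ids, params=params).last_hidden_state
    return jnp.mean(hidden_states**2)  # dummy scalar loss

# With checkpointing enabled, layer activations are recomputed in the
# backward pass rather than stored, trading compute for memory.
grads = jax.grad(loss_fn)(model.params)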

TODO:

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

cc @borisdayma

@sanchit-gandhi added the WIP (work in progress) label on Jun 23, 2022
Comment on lines -585 to +601

-                layer_head_mask=head_mask[i] if head_mask is not None else None,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_attention_mask=encoder_attention_mask,
-                init_cache=init_cache,
-                deterministic=deterministic,
-                output_attentions=output_attentions,
+                head_mask[i] if head_mask is not None else None,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                init_cache,
+                deterministic,
+                output_attentions,
@sanchit-gandhi (Contributor, Author):

Note: remat does not support kwargs, hence the need to change to args
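
For context, a minimal self-contained sketch (hypothetical TinyLayer, not the PR code) of this constraint: a module class wrapped with flax.linen.remat has to be called with positional arguments, which is why the call site above switches from kwargs to args.

import jax
import jax.numpy as jnp
import flax.linen as nn

class TinyLayer(nn.Module):
    features: int = 8

    @nn.compact
    def __call__(self, hidden_states, attention_mask):
        hidden_states = nn.Dense(self.features)(hidden_states)
        return hidden_states * attention_mask

# remat returns a gradient-checkpointed version of the module class
RematTinyLayer = nn.remat(TinyLayer)

layer = RematTinyLayer(features=8)
x = jnp.ones((2, 4, 8))
mask = jnp.ones((2, 4, 1))
params = layer.init(jax.random.PRNGKey(0), x, mask)  # positional args
out = layer.apply(params, x, mask)                   # not attention_mask=mask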

Contributor:

ok!

@HuggingFaceDocBuilderDev commented on Jun 23, 2022

The documentation is not available anymore as the PR was closed or merged.

@borisdayma (Contributor):

Is there any inconvenience in adding it to all layers?

In my case, I used it only on the transformer blocks (attention + feed-forward).

@sanchit-gandhi (Contributor, Author) commented on Jun 27, 2022

> Is there any inconvenience in adding it to all layers?

By wrapping FlaxBertLayer in a remat operation, each Bert layer (attention, intermediate FF, final FF + optional cross-attention layers) has remat applied to it:

FlaxBertCheckpointLayer = remat(FlaxBertLayer, static_argnums=(5, 6, 7), policy=self.remat_policy)

We then use this remat'd layer to construct the Transformer block (layers collection):
self.layers = [
    FlaxBertCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
    for i in range(self.config.num_hidden_layers)
]

This means that each component of the Bert layer is checkpointed, and that all Bert layers in the Transformer block (layers collection) are checkpointed.

Would you like to see remat on the embeddings and pooler layers too? I imagine this wouldn't make a huge difference to performance at train time versus just checkpointing the entire Transformer block.
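
For reference, here is a self-contained toy version of this pattern (the TinyLayer / TinyEncoder names are hypothetical, not the PR code): wrap the layer class in remat once, then build the whole stack from the wrapped class so that every layer is checkpointed.

import jax
import jax.numpy as jnp
import flax.linen as nn

class TinyLayer(nn.Module):
    features: int = 8

    @nn.compact
    def __call__(self, hidden_states):
        # stand-in for attention + feed-forward with a residual connection
        return nn.Dense(self.features)(hidden_states) + hidden_states

class TinyEncoder(nn.Module):
    num_layers: int = 4
    gradient_checkpointing: bool = True

    def setup(self):
        # wrap once, then instantiate every layer from the wrapped class
        layer_cls = nn.remat(TinyLayer) if self.gradient_checkpointing else TinyLayer
        self.layers = [layer_cls(name=str(i)) for i in range(self.num_layers)]

    def __call__(self, hidden_states):
        for layer in self.layers:
            hidden_states = layer(hidden_states)
        return hidden_states

encoder = TinyEncoder()
x = jnp.ones((2, 16, 8))
params = encoder.init(jax.random.PRNGKey(0), x)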

@sanchit-gandhi removed the WIP (work in progress) label on Jun 27, 2022
@patrickvonplaten (Contributor) left a review:

Very nice! Also cc @younesbelkada

We could also look into implementing this for OPT and BLOOM in Flax :-) Great job @sanchit-gandhi

The only feedback from my side would be to remove the option to override the policy (also since we don't test it).


def setup(self):
    if self.gradient_checkpointing:
        FlaxBertCheckpointLayer = remat(FlaxBertLayer, static_argnums=(5, 6, 7), policy=self.remat_policy)
Contributor:

Suggested change:
-        FlaxBertCheckpointLayer = remat(FlaxBertLayer, static_argnums=(5, 6, 7), policy=self.remat_policy)
+        FlaxBertLayer = remat(FlaxBertLayer, static_argnums=(5, 6, 7), policy=self.remat_policy)

(nit) I'd just keep the original FlaxBertLayer name. IMO it's easier to read the code and compare to PyTorch this way, but I'm also happy to leave it as is.

@sanchit-gandhi (Contributor, Author):

remat prevented re-use of the class name FlaxBertLayer: see google/flax#2251.
We can rename it in a follow-up PR if we find a workaround!

@@ -617,9 +628,16 @@ def __call__(
class FlaxBertEncoder(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    gradient_checkpointing: bool = False
    remat_policy: Callable[..., bool] = (None,)  # the gradient checkpointing policy
Contributor:

Are there multiple policies? Would one ever use another one than the default? I wonder whether allowing this parameter to be customizable might be a bit scary for the user and make the whole functionality less understandable. I think I'd prefer to just use the default here and not allow the user to configure it.

@sanchit-gandhi (Contributor, Author):

The full list of remat policies can be found in jax.checkpoint_policies. They dictate whether output value(s) are saved as residuals or whether they must be recomputed in the (co)tangent computation.

The advice for selecting an appropriate remat policy is empirically driven: try them all and see what works best! On paper, dot_with_no_batch_dims should work best for Transformer architectures, and it was indeed the preferred policy for T5X. However, for the Seq2Seq project, I found the default policy to be optimal!

I agree that including remat_policy as an attribute is probably too heavy and clutters the code. It's straightforward to set one's own policy by overriding the policy arg of the remat call, and users who wish to do so can do this easily.
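
For anyone who does want to experiment, a hedged sketch of picking a policy by hand (TinyBlock is a hypothetical stand-in for any Flax layer class, not this PR's code):

import jax
import jax.numpy as jnp
import flax.linen as nn

class TinyBlock(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(x.shape[-1])(x)

# Save only the outputs of matmuls with no batch dimensions (the T5X-style
# policy); with the default (policy=None), remat recomputes everything in
# the backward pass.
policy = jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims
CheckpointedTinyBlock = nn.remat(TinyBlock, policy=policy)

block = CheckpointedTinyBlock()
x = jnp.ones((2, 16))
params = block.init(jax.random.PRNGKey(0), x)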

@borisdayma (Contributor):

> Would you like to see remat on the embeddings and pooler layers too? I imagine this wouldn't make a huge difference to performance at train time versus just checkpointing the entire Transformer block.

No, actually I thought it was on all layers, but the way you did it is great!

@patrickvonplaten (Contributor):

Cool! Once the tests are green, happy to merge it here :-)

@sanchit-gandhi sanchit-gandhi merged commit 485bbe7 into huggingface:main Jul 1, 2022
@sanchit-gandhi sanchit-gandhi deleted the flax-remat branch July 1, 2022 17:33
viclzhu pushed a commit to viclzhu/transformers that referenced this pull request Jul 18, 2022
* [Flax] Add remat (gradient checkpointing)

* fix variable naming in test

* flip: checkpoint using a method

* fix naming

* fix class naming

* apply PVP's suggestions from code review

* make fix-copies

* fix big-bird, electra, roberta

* cookie-cutter

* fix flax big-bird

* move test to common